{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning to compute a product\n",
    "\n",
    "Unlike the communication channel and the element-wise square,\n",
    "the product is a nonlinear function of multiple inputs.\n",
    "This represents a difficult case for learning rules\n",
    "that aim to generalize a function given many\n",
    "input-output example pairs.\n",
    "However, using the same type of network structure\n",
    "as in the communication channel and square cases,\n",
    "we can learn to compute the product of two dimensions\n",
    "with the `nengo.PES` learning rule."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "import nengo\n",
    "from nengo.processes import WhiteSignal"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the model\n",
    "\n",
    "Like previous examples,\n",
    "the network consists of `pre`, `post`, and `error` ensembles.\n",
    "We'll use two-dimensional white noise input and attempt to learn the product\n",
    "using the actual product to compute the error signal."
   ]
  },
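  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what the learning connection in this model does,\n",
    "and assuming the usual formulation of PES described in the Nengo documentation,\n",
    "each decoder of the `pre` to `post` connection is nudged against the error:\n",
    "\n",
    "$$\n",
    "\\Delta \\mathbf{d}_i = -\\frac{\\kappa}{n} \\, e \\, a_i,\n",
    "\\qquad e \\approx \\hat{y}_{\\text{post}} - x_1 x_2\n",
    "$$\n",
    "\n",
    "Here $\\kappa$ is the learning rate, $n$ is the number of `pre` neurons,\n",
    "$a_i$ is the filtered activity of `pre` neuron $i$,\n",
    "$\\hat{y}_{\\text{post}}$ is the value decoded from `post`,\n",
    "and $x_1 x_2$ is the target product represented by the reference population.\n",
    "These symbols are introduced here for illustration and do not appear in the code;\n",
    "the update applied on each step is additionally scaled by the simulator timestep."
   ]
  },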
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = nengo.Network()\n",
    "with model:\n",
    "    # -- input and pre population\n",
    "    inp = nengo.Node(WhiteSignal(60, high=5), size_out=2)\n",
    "    pre = nengo.Ensemble(120, dimensions=2)\n",
    "    nengo.Connection(inp, pre)\n",
    "\n",
    "    # -- post population\n",
    "    post = nengo.Ensemble(60, dimensions=1)\n",
    "\n",
    "    # -- reference population, containing the actual product\n",
    "    product = nengo.Ensemble(60, dimensions=1)\n",
    "    nengo.Connection(inp, product, function=lambda x: x[0] * x[1], synapse=None)\n",
    "\n",
    "    # -- error population\n",
    "    error = nengo.Ensemble(60, dimensions=1)\n",
    "    nengo.Connection(post, error)\n",
    "    nengo.Connection(product, error, transform=-1)\n",
    "\n",
    "    # -- learning connection\n",
    "    conn = nengo.Connection(\n",
    "        pre,\n",
    "        post,\n",
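    "        # start from a random function, so the connection does not\n",
    "        # compute the product until the decoders have been learned\n",
    "        function=lambda x: np.random.random(1),\n",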
    "        learning_rule_type=nengo.PES(),\n",
    "    )\n",
    "    nengo.Connection(error, conn.learning_rule)\n",
    "\n",
    "    # -- inhibit the error population after 40 seconds, which stops learning\n",
    "    inhib = nengo.Node(lambda t: 2.0 if t > 40.0 else 0.0)\n",
    "    nengo.Connection(inhib, error.neurons, transform=[[-1]] * error.n_neurons)\n",
    "\n",
    "    # -- probes\n",
    "    product_p = nengo.Probe(product, synapse=0.01)\n",
    "    pre_p = nengo.Probe(pre, synapse=0.01)\n",
    "    post_p = nengo.Probe(post, synapse=0.01)\n",
    "    error_p = nengo.Probe(error, synapse=0.03)\n",
    "\n",
    "with nengo.Simulator(model) as sim:\n",
    "    sim.run(60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 8))\n",
    "plt.subplot(3, 1, 1)\n",
    "plt.plot(sim.trange(), sim.data[pre_p], c=\"b\")\n",
    "plt.legend((\"Pre decoding\",), loc=\"best\")\n",
    "plt.subplot(3, 1, 2)\n",
    "plt.plot(sim.trange(), sim.data[product_p], c=\"k\", label=\"Actual product\")\n",
    "plt.plot(sim.trange(), sim.data[post_p], c=\"r\", label=\"Post decoding\")\n",
    "plt.legend(loc=\"best\")\n",
    "plt.subplot(3, 1, 3)\n",
    "plt.plot(sim.trange(), sim.data[error_p], c=\"b\")\n",
    "plt.ylim(-1, 1)\n",
    "plt.legend((\"Error\",), loc=\"best\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Examine the initial output\n",
    "\n",
    "Let's zoom in on the network at the beginning:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 8))\n",
    "plt.subplot(3, 1, 1)\n",
    "plt.plot(sim.trange()[:2000], sim.data[pre_p][:2000], c=\"b\")\n",
    "plt.legend((\"Pre decoding\",), loc=\"best\")\n",
    "plt.subplot(3, 1, 2)\n",
    "plt.plot(sim.trange()[:2000], sim.data[product_p][:2000], c=\"k\", label=\"Actual product\")\n",
    "plt.plot(sim.trange()[:2000], sim.data[post_p][:2000], c=\"r\", label=\"Post decoding\")\n",
    "plt.legend(loc=\"best\")\n",
    "plt.subplot(3, 1, 3)\n",
    "plt.plot(sim.trange()[:2000], sim.data[error_p][:2000], c=\"b\")\n",
    "plt.ylim(-1, 1)\n",
    "plt.legend((\"Error\",), loc=\"best\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The plot above shows the first two seconds of the simulation.\n",
    "When the network is initialized, it cannot compute the product,\n",
    "so the error is quite large."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Examine the final output\n",
    "\n",
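    "After the network has run for a while and learning has occurred,\n",
    "the output is quite different.\n",
    "The plot below covers 38 to 42 seconds,\n",
    "spanning the point at 40 seconds where the error population is inhibited:"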
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 8))\n",
    "plt.subplot(3, 1, 1)\n",
    "plt.plot(sim.trange()[38000:42000], sim.data[pre_p][38000:42000], c=\"b\")\n",
    "plt.legend((\"Pre decoding\",), loc=\"best\")\n",
    "plt.subplot(3, 1, 2)\n",
    "plt.plot(\n",
    "    sim.trange()[38000:42000],\n",
    "    sim.data[product_p][38000:42000],\n",
    "    c=\"k\",\n",
    "    label=\"Actual product\",\n",
    ")\n",
    "plt.plot(\n",
    "    sim.trange()[38000:42000],\n",
    "    sim.data[post_p][38000:42000],\n",
    "    c=\"r\",\n",
    "    label=\"Post decoding\",\n",
    ")\n",
    "plt.legend(loc=\"best\")\n",
    "plt.subplot(3, 1, 3)\n",
    "plt.plot(sim.trange()[38000:42000], sim.data[error_p][38000:42000], c=\"b\")\n",
    "plt.ylim(-1, 1)\n",
    "plt.legend((\"Error\",), loc=\"best\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see that it has learned a decent approximation of the product,\n",
    "but it's not perfect -- typically,\n",
    "it's not as good as the offline optimization.\n",
    "The reason is that we've given it white noise input,\n",
    "which has a mean of 0 in both dimensions,\n",
    "so we see a lot of examples of inputs and outputs near 0.\n",
    "In other words, we've oversampled a certain part of the\n",
    "vector space, and overlearned decoders that do well in\n",
    "that part of the space. If we wanted to do better in other\n",
    "parts of the space, we would need to construct an input\n",
    "signal that samples the space evenly."
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
